# -*- coding: utf-8 -*-
"""
@date: 2020/12/18 20:48
@file: linera_stu.py
@author: lilong
@desc: nn.Linear简单使用
"""

import torch
import torch.nn as nn


# in_features由输入张量的形状决定，out_features则决定了输出张量的形状
connected_layer = nn.Linear(in_features=4 * 4 * 2, out_features=1)

# 假定输入的图像形状为[64,64,3]
input = torch.randn(1, 4, 4, 2)

# 将四维张量转换为二维张量之后，才能作为全连接层的输入
input = input.view(1, 4 * 4 * 2)
print(input.shape)
output = connected_layer(input)  # 调用全连接层
print(output.shape)